{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8b4bc5a3-1f6c-431d-8fe1-81921076fc5e",
   "metadata": {},
   "source": [
    "# Fine-Tune an Embedding Model\n",
    "\n",
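    "Fine-tune a Hugging Face embedding model with the Sentence Transformers library, using triplet loss on a few hand-written examples, and compare the cosine similarity of two test sentences before and after training."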
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d465f9ff-6a07-4bf1-9867-5d297d18a49d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (unlike !pip3) installs into the environment of the running kernel.\n",
    "%pip install sentence-transformers==3.0.1 torch==2.2.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1680c03-3635-49fc-9182-490dc01b1f2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# This has to be a real environment variable: a plain Python assignment is never\n",
    "# seen by PyTorch. It is set before torch is imported so it is already in place\n",
    "# when the MPS allocator starts up. 0.0 removes the allocator's upper memory limit.\n",
    "os.environ[\"PYTORCH_MPS_HIGH_WATERMARK_RATIO\"] = \"0.0\"\n",
    "\n",
    "import torch\n",
    "import sentence_transformers\n",
    "\n",
    "print(\"torch version:\", torch.__version__)\n",
    "print(\"sentence transformers version:\", sentence_transformers.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1204b30-75e4-48c6-8990-2ede687e8e11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# An empty device list leaves no CUDA GPU visible to this process.\n",
    "os.environ.update(CUDA_VISIBLE_DEVICES=\"\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a265166-e6a1-4ab3-b6ed-f668cb9ae5ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer, InputExample, losses, util\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Checkpoint to fine-tune, referenced by its Hugging Face Hub name.\n",
    "# NOTE: trust_remote_code=True allows Python code shipped in the model repository to run locally.\n",
    "MODEL_NAME = \"Alibaba-NLP/gte-base-en-v1.5\"\n",
    "model = SentenceTransformer(MODEL_NAME, trust_remote_code=True)\n",
    "\n",
    "# Function to print similarity score\n",
    "def get_similarity_score(\n",
    "    sentence1=\"I love the taste of fresh apples.\",\n",
    "    sentence2=\"Apples are rich in vitamins and fiber.\",\n",
    "):\n",
    "    \"\"\"Print and return the cosine similarity of two sentences.\n",
    "\n",
    "    Both sentences are embedded with the notebook-level `model`, so calling this\n",
    "    before and after fine-tuning shows how training moved the score.\n",
    "\n",
    "    Args:\n",
    "        sentence1: First sentence to embed; defaults to the notebook's probe sentence.\n",
    "        sentence2: Second sentence to embed; defaults to the notebook's probe sentence.\n",
    "\n",
    "    Returns:\n",
    "        The value returned by `util.cos_sim` for the two embeddings (a tensor\n",
    "        holding a single score); call `.item()` on it to get a plain float.\n",
    "    \"\"\"\n",
    "    embedding1 = model.encode(sentence1)\n",
    "    embedding2 = model.encode(sentence2)\n",
    "    cosine_score = util.cos_sim(embedding1, embedding2)\n",
    "    print(f\"Cosine similarity between '{sentence1}' and '{sentence2}': {cosine_score.item():.4f}\")\n",
    "    return cosine_score\n",
    "\n",
    "# Baseline: similarity score before any fine-tuning\n",
    "print(\"Before training:\")\n",
    "similarity_before = get_similarity_score()\n",
    "\n",
    "# Each example is a triplet for TripletLoss: [anchor, positive, negative].\n",
    "# The negative reuses a word from the anchor in an unrelated sense.\n",
    "train_examples = [\n",
    "    InputExample(texts=[\"I love eating apples.\", \"Apples are my favorite fruit\", \"Apple is a tech company\"]),\n",
    "    InputExample(texts=[\"Chocolate is a sweet treat loved by many.\", \"I can't resist a good piece of chocolate.\", \"Chocolate Rain was one of the most popular songs on YouTube from 2007.\"]),\n",
    "    InputExample(texts=[\"Ice cream is a refreshing dessert.\", \"I love trying different ice cream flavors.\", \"The rapper and actor Ice Cube was wearing a cream colored suit to the VMAs.\"]),\n",
    "    InputExample(texts=[\"Salad is a healthy meal option.\", \"I love a fresh, crisp salad with various vegetables.\", \"Salad Fingers is a surreal web series created by David Firth.\"]),\n",
    "]\n",
    "# batch_size (8) exceeds the number of examples (4), so every epoch is a single batch.\n",
    "train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=8)\n",
    "train_loss = losses.TripletLoss(model=model)\n",
    "\n",
    "# Fine-tune: this updates `model` itself, so the score below reflects the trained weights\n",
    "model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=10)\n",
    "\n",
    "print(\"After training:\")\n",
    "similarity_after = get_similarity_score()\n",
    "\n",
    "similarity_difference = similarity_after - similarity_before\n",
    "# '.4f' (not '4f', which is width 4 with six decimals) prints four decimal places,\n",
    "# matching the precision of the scores printed by get_similarity_score().\n",
    "print(f\"Change in similarity score: {similarity_difference.item():.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "802738a8-a056-4845-9622-a3ad479b72c4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
